{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "%matplotlib inline\n",
    "%config InlineBackend.print_figure_kwargs = {'dpi':200, 'bbox_inches': 'tight'}\n",
    "from matplotlib_latex_configurations import *\n",
    "rcParams['figure.figsize'] = (one_column_figure_size * golden_ration, one_column_figure_size)\n",
    "\n",
    "import numpy as np\n",
    "import networkx as nx\n",
    "import scipy as sc\n",
    "import pylab as plt\n",
    "from tqdm import tqdm\n",
    "import scipy.interpolate as sci\n",
    "import networkx as nx\n",
    "import pickle\n",
    "import itertools\n",
    "marker = itertools.cycle(('p', '+', '.', 'o', '*', 's', 'v'))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Phase transition"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pickle\n",
    "import math\n",
    "\n",
    "# Mean degrees k-bar to analyse (one result file and one curve per value).\n",
    "# NOTE(review): `c` is not defined anywhere else in this notebook; values inferred from the\n",
    "# k-bar legend labels of the phase-transition figure - confirm against the .pkl result files.\n",
    "c = [5, 8, 10, 15, 20, 30]\n",
    "# Number of c_in values (points on each curve) per mean degree.\n",
    "cases = 16\n",
    "# Time grid for compute_separation2 (not called in this cell).\n",
    "times = np.logspace(-2, 2., 25)\n",
    "\n",
    "fig, ax2 = plt.subplots(figsize=(5,3))\n",
    "\n",
    "def plot_phase_transition(r, kappamin_avg, kappamin_std, ax2):\n",
    "    \"\"\"Plot the mean separation against r on ax2 with a shaded +/- one std band.\n",
    "\n",
    "    The marker style is drawn from the module-level `marker` cycle so that\n",
    "    successive curves on the same axes are distinguishable.\n",
    "    \"\"\"\n",
    "    mean = np.array(kappamin_avg)\n",
    "    std = np.array(kappamin_std)\n",
    "\n",
    "    # Only the line style goes in the fmt string: giving a marker both in fmt ('-o')\n",
    "    # and as a keyword makes matplotlib warn that the marker is redundantly defined.\n",
    "    ax2.plot(r, mean, '-', marker=next(marker), markersize=7)\n",
    "    ax2.fill_between(r, mean + std, mean - std, alpha=0.35)\n",
    "\n",
    "\n",
    "def compute_intersection(r, kappamin_avg, kappamin_std, th=0.2):\n",
    "    \"\"\"Return the r values at which the separation curves cross the threshold `th`.\n",
    "\n",
    "    Each of the curves mean - std, mean and mean + std is inverted by linear\n",
    "    interpolation (separation -> r) and evaluated at `th`.\n",
    "\n",
    "    Returns [r_minus, r_mean, r_plus] as floats, in that order.\n",
    "    Only the mean - std curve is extrapolated; for the other two, a `th` outside\n",
    "    the sampled range raises ValueError (scipy interp1d default bounds check).\n",
    "    \"\"\"\n",
    "    mean = np.array(kappamin_avg)\n",
    "    std = np.array(kappamin_std)\n",
    "\n",
    "    # NOTE(review): inverse interpolation assumes each curve is monotonic in r;\n",
    "    # interp1d sorts by x, so a non-monotonic curve gives an ambiguous crossing.\n",
    "    gplus = sci.interp1d(mean + std, r)\n",
    "    g = sci.interp1d(mean, r)\n",
    "    gminus = sci.interp1d(mean - std, r, fill_value=\"extrapolate\")\n",
    "\n",
    "    return [float(gminus(th)), float(g(th)), float(gplus(th))]\n",
    "\n",
    "\n",
    "def compute_separation(graph, kappa, n=None, kappa_threshold=0.75):\n",
    "    \"\"\"Sensitivity index d' between within- and between-community edge curvatures.\n",
    "\n",
    "    Edges whose endpoints lie on opposite sides of n/2 are 'between' edges, all\n",
    "    others are 'within' edges. The curvatures of each group are averaged per time\n",
    "    step, and d' is evaluated at the first time step at which the mean\n",
    "    within-community curvature reaches `kappa_threshold`:\n",
    "\n",
    "        d' = |mu_within - mu_between| / sqrt((var_within + var_between) / 2)\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    graph : graph whose `.edges` are node pairs, in the same order as kappa's columns\n",
    "    kappa : array indexed kappa[:, edge]; the first axis is the time step\n",
    "    n : number of nodes, defaults to graph.number_of_nodes(). Assumes nodes are\n",
    "        labelled so that one community is exactly the labels < n/2 (TODO confirm).\n",
    "    kappa_threshold : mean within-community curvature that selects the time step\n",
    "\n",
    "    Raises\n",
    "    ------\n",
    "    ValueError if the mean within-community curvature never reaches kappa_threshold.\n",
    "    \"\"\"\n",
    "    half = (graph.number_of_nodes() if n is None else n) / 2\n",
    "\n",
    "    kappa_within = []\n",
    "    kappa_between = []\n",
    "    for i, e in enumerate(graph.edges):\n",
    "        # Classify each endpoint independently so the result does not depend on\n",
    "        # the order in which the graph reports the two endpoints of an edge.\n",
    "        if (e[0] < half) != (e[1] < half):\n",
    "            kappa_between.append(kappa[:, i])\n",
    "        else:\n",
    "            kappa_within.append(kappa[:, i])\n",
    "\n",
    "    kappa_within_mean = np.mean(kappa_within, axis=0)\n",
    "    kappa_between_mean = np.mean(kappa_between, axis=0)\n",
    "    kappa_within_var = np.var(kappa_within, axis=0)\n",
    "    kappa_between_var = np.var(kappa_between, axis=0)\n",
    "\n",
    "    reached = np.flatnonzero(kappa_within_mean >= kappa_threshold)\n",
    "    if reached.size == 0:\n",
    "        raise ValueError(\"mean within-community curvature never reaches kappa_threshold\")\n",
    "    indx = reached[0]\n",
    "\n",
    "    # Pooled standard deviation of the two edge groups.\n",
    "    pooled_std = np.sqrt(0.5 * (kappa_within_var[indx] + kappa_between_var[indx]))\n",
    "    return abs(kappa_within_mean[indx] - kappa_between_mean[indx]) / pooled_std\n",
    "\n",
    "\n",
    "def compute_separation2(lambda2, times, kappa):\n",
    "    \"\"\"Mean over edges of each edge's curvature, linearly interpolated at lambda2.\n",
    "\n",
    "    `kappa` holds one column per edge (kappa[:, edge]), sampled at `times`.\n",
    "    \"\"\"\n",
    "    n_edges = kappa.shape[1]\n",
    "    per_edge = (sci.interp1d(times, kappa[:, col])(lambda2) for col in range(n_edges))\n",
    "    return sum(per_edge) / n_edges\n",
    "\n",
    "\n",
    "rc2 = []\n",
    "for c_ in c:\n",
    "\n",
    "    # c_in + c_out == c_; r = c_out / c_in is the x-axis of the figure.\n",
    "    c_in = np.linspace(c_*0.5, c_*0.9, cases)\n",
    "    c_out = c_ - c_in\n",
    "    r = c_out/c_in\n",
    "\n",
    "    # NOTE(review): pickle.load can execute arbitrary code - only load trusted result files.\n",
    "    result_path = \"/data/AG/geocluster/phase_transition/phase_transition_curvature_final_k\" + str(c_) + \"_200.pkl\"\n",
    "    with open(result_path, \"rb\") as fh:\n",
    "        a = pickle.load(fh)\n",
    "\n",
    "    # a[trial][case][0] is the graph and a[trial][case][1] its curvature array.\n",
    "    ntrials = len(a)\n",
    "    kappamin_avg = []\n",
    "    kappamin_std = []\n",
    "    for case in range(cases):\n",
    "        # Keep only the trials whose curvature contains no NaN.\n",
    "        runs = [a[trial][case] for trial in range(ntrials)]\n",
    "        successful = [run for run in runs if not np.isnan(run[1]).any()]\n",
    "        kappamin = [compute_separation(run[0], run[1]) for run in successful]\n",
    "\n",
    "        # Average the separation over the successful trials.\n",
    "        kappamin_avg.append(np.mean(kappamin))\n",
    "        kappamin_std.append(np.std(kappamin))\n",
    "\n",
    "    plot_phase_transition(r, kappamin_avg, kappamin_std, ax2)\n",
    "    rc2.append(compute_intersection(r, kappamin_avg, kappamin_std))\n",
    "\n",
    "ax2.set_xlabel(r'Edge density ratio, $r$')\n",
    "ax2.set_ylabel(r\"Sensitivity index, $d'$\")\n",
    "ax2.set_xlim([0, 1])\n",
    "ax2.set_ylim([0.01, 12])\n",
    "ax2.set_yscale('log')\n",
    "\n",
    "# One legend entry per mean degree, in the order the curves were drawn.\n",
    "ax2.legend([rf'$\\bar{{k}}={c_:g}$' for c_ in c]);\n",
    "#plt.savefig('pt_curvature.svg')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Reference curve r*(k) = (k - sqrt(k)) / (k + sqrt(k)).\n",
    "k = np.linspace(0,50)\n",
    "rcrit = (k - np.sqrt(k))/(k + np.sqrt(k))\n",
    "\n",
    "plt.plot(k, rcrit)\n",
    "\n",
    "# NOTE(review): `rc` is not defined anywhere in this notebook; it must come from another\n",
    "# cell/notebook with rows [r_minus, r_mean, r_plus] like rc2 (TODO confirm).\n",
    "for bounds in (np.array(rc), np.array(rc2)):\n",
    "    lower, centre, upper = bounds[:, 0], bounds[:, 1], bounds[:, 2]\n",
    "    # errorbar expects (2, N) distances from the centre, not absolute bounds.\n",
    "    yerr = np.vstack([np.abs(centre - lower), np.abs(upper - centre)])\n",
    "    plt.errorbar(c, centre, yerr=yerr, fmt='o')\n",
    "\n",
    "plt.xlabel(r'Mean degree, $\\overline{k}$')\n",
    "plt.ylabel(r'Critical edge density ratio, $r_{\\overline{k}}^*$')\n",
    "#plt.savefig('critical_density.svg')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
